import random
import torch
from typing import List, Dict
import numpy as np
import argparse
import json
import os
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoTokenizer
import faiss
import faiss.contrib.torch_utils
import pickle

class FaissRAGReasoningPreprocessor:
    """Build positive and "hard-mixed" negative (answer, reasoning) samples.

    Positives pair each item's answer with every one of its reasoning
    sentences. Negatives splice a sampled negative snippet into a random
    window of one of the item's real reasoning sentences (see
    ``_construct_hard_negative``). The snippet comes from ``neg_strategy``;
    the ``*_rag`` strategies retrieve candidates from a FAISS inner-product
    index over normalized embeddings.
    """

    # Strategies that sample directly from the vocab or from the item itself.
    _LOCAL_STRATEGIES = ("random_vocab", "qa_sampling", "sentence_token_sampling")
    # Strategies that retrieve candidates through a FAISS index.
    _RAG_STRATEGIES = ("vocab_rag", "sample_rag", "external_rag")
    # Vocab tokens retrieved per anchor for "vocab_rag".
    _VOCAB_TOP_K = 200
    # Neighbours retrieved per anchor for "sample_rag" / "external_rag".
    _NEIGHBOUR_TOP_K = 10

    def __init__(self, data: List[Dict], embedder, tokenizer, vocab_indices: List[int],
                 neg_token_length=20, neg_strategy="random_vocab", anchor_source="answer",
                 cache_dir="./cache",
                 wiki_dir="/workspace/0407_nips/embd_train/train/0_mk_data_fixed/wiki_rag"):
        """
        Args:
            data: list of dicts with 'answer' and 'reasoning_sentences'
                ('question' is also read for anchor_source="question" and
                neg_strategy="qa_sampling").
            embedder: embedding model exposing ``encode`` (used for RAG strategies).
            tokenizer: tokenizer used to encode text and decode token ids.
            vocab_indices: token ids to sample negatives from
                ("random_vocab" / "vocab_rag").
            neg_token_length: number of tokens in each sampled negative snippet.
            neg_strategy: one of "random_vocab", "vocab_rag", "sample_rag",
                "qa_sampling", "sentence_token_sampling", "external_rag".
            anchor_source: "answer", "question" or "reasoning_random"; the
                text whose embedding drives RAG retrieval.
            cache_dir: directory created for cached embeddings/indices
                (caching itself is currently disabled).
            wiki_dir: root containing ``rag_index/index.faiss`` and
                ``rag_index/texts.pkl``; only read for "external_rag".
        """
        self.data = data
        self.embedder = embedder
        self.tokenizer = tokenizer
        self.vocab_indices = vocab_indices
        self.neg_token_length = neg_token_length
        self.neg_strategy = neg_strategy
        self.anchor_source = anchor_source
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

        # Only the index needed by the chosen strategy is built/loaded.
        if self.neg_strategy == "vocab_rag":
            self._build_vocab_rag()
        elif self.neg_strategy == "sample_rag":
            self._build_sample_rag()
        elif self.neg_strategy == "external_rag":
            self._load_external_rag(wiki_dir)

    def _load_external_rag(self, rag_root):
        """Load a prebuilt FAISS index and its texts from ``<rag_root>/rag_index``.

        Reads ``index.faiss`` and ``texts.pkl``. Search result id ``i`` is
        later used to look up ``self.external_texts[i]``, so the two files
        are assumed to be aligned row-for-row. The index is moved to GPU 0
        when FAISS reports a GPU.
        """
        print("🔵 Loading external RAG index and texts from FAISS...")

        rag_dir = os.path.join(rag_root, "rag_index")
        index_path = os.path.join(rag_dir, "index.faiss")
        texts_path = os.path.join(rag_dir, "texts.pkl")

        cpu_index = faiss.read_index(index_path)

        if faiss.get_num_gpus() > 0:
            print("🚀 Transferring FAISS index to GPU...")
            res = faiss.StandardGpuResources()
            self.external_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
        else:
            print("⚠️  No GPU detected, using CPU index.")
            self.external_index = cpu_index

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # wiki_dir at trusted, locally produced files.
        with open(texts_path, "rb") as f:
            self.external_texts = pickle.load(f)

        print(f"🔵 Loaded external FAISS index with {self.external_index.ntotal} entries.")
        print(f"🔵 Loaded {len(self.external_texts)} external texts.")
        if self.external_index.ntotal != len(self.external_texts):
            print("⚠️  Index size and text count differ; ids without a matching text are ignored.")

    def _compute_embeddings(self, texts: List[str], show_progress_bar: bool = True) -> np.ndarray:
        """Encode ``texts`` with the embedder and return a float32 array.

        Embeddings are requested with ``normalize_embeddings=True``. Pass
        ``show_progress_bar=False`` for small per-item calls made inside an
        outer progress loop, otherwise a bar is printed for every call.
        """
        embeddings = self.embedder.encode(
            texts, batch_size=64, show_progress_bar=show_progress_bar, normalize_embeddings=True
        )
        return np.array(embeddings).astype(np.float32)

    def _build_faiss_index(self, embeddings: np.ndarray):
        """Build an inner-product FAISS index over ``embeddings``; move it to GPU 0 if available.

        ``faiss.normalize_L2`` normalizes ``embeddings`` in place, so inner
        product equals cosine similarity.
        """
        d = embeddings.shape[1]
        cpu_index = faiss.IndexFlatIP(d)
        faiss.normalize_L2(embeddings)
        cpu_index.add(embeddings)

        if faiss.get_num_gpus() > 0:
            print("🚀 Using GPU FAISS index")
            gpu_res = faiss.StandardGpuResources()
            return faiss.index_cpu_to_gpu(gpu_res, 0, cpu_index)
        print("⚠️  Using CPU FAISS index")
        return cpu_index

    def _build_vocab_rag(self):
        """Embed every vocab token (decoded individually) and index the embeddings.

        Sets ``self.vocab_embeddings`` and ``self.vocab_index``; index row i
        corresponds to ``self.vocab_indices[i]``. On-disk caching in
        ``cache_dir`` is currently disabled, so this recomputes every run.
        """
        print("🔵 Building vocab embeddings and FAISS index from scratch...")
        vocab_tokens = [self.tokenizer.decode([idx], skip_special_tokens=True) for idx in self.vocab_indices]
        self.vocab_embeddings = self._compute_embeddings(vocab_tokens)
        self.vocab_index = self._build_faiss_index(self.vocab_embeddings)

    def _build_sample_rag(self):
        """Embed every item's answer and index the embeddings.

        Sets ``self.answer_embeddings`` and ``self.index``; index row i
        corresponds to ``self.data[i]``. On-disk caching in ``cache_dir`` is
        currently disabled, so this recomputes every run.
        """
        print("🔵 Building answer embeddings and FAISS index from scratch...")
        self.answer_embeddings = self._compute_embeddings([item["answer"] for item in self.data])
        self.index = self._build_faiss_index(self.answer_embeddings)

    @staticmethod
    def _random_window(tokens: List[int], length: int) -> List[int]:
        """Return a random contiguous window of ``length`` tokens, or all tokens if there are fewer."""
        if len(tokens) > length:
            start = random.randint(0, len(tokens) - length)
            return tokens[start:start + length]
        return tokens[:length]

    def _construct_hard_negative(
        self, pos_text: str, neg_text: str, max_length: int = 128
    ) -> Dict[str, object]:
        """Splice a window of ``neg_text`` into a window of ``pos_text``.

        A ratio ``r`` in [0.2, 0.8] sets the positive budget to
        ``int(r * max_length)`` tokens; the negative gets the rest. The
        negative window is inserted at a random position inside the positive
        window, so the result has at most ``max_length`` tokens.

        Returns:
            {"text": mixed decoded text,
             "meta": {"ratio", "insert_pos", "pos_text", "neg_text"}}
        """
        r = random.uniform(0.2, 0.8)
        pos_tokens = self.tokenizer.encode(pos_text, add_special_tokens=False)
        neg_tokens = self.tokenizer.encode(neg_text, add_special_tokens=False)

        pos_len = int(r * max_length)
        neg_len = max_length - pos_len

        pos_tokens = self._random_window(pos_tokens, pos_len)
        neg_tokens = self._random_window(neg_tokens, neg_len)

        insert_pos = random.randint(0, len(pos_tokens))
        mixed_tokens = (pos_tokens[:insert_pos] + neg_tokens + pos_tokens[insert_pos:])[:max_length]

        return {
            "text": self.tokenizer.decode(mixed_tokens, skip_special_tokens=True),
            "meta": {
                "ratio": round(r, 4),
                "insert_pos": insert_pos,
                "pos_text": self.tokenizer.decode(pos_tokens, skip_special_tokens=True),
                "neg_text": self.tokenizer.decode(neg_tokens, skip_special_tokens=True),
            },
        }

    def _sample_negative_text(self, anchor_text: str, anchor_idx: int) -> str:
        """Return one negative snippet for the anchor of ``self.data[anchor_idx]``.

        Raises:
            ValueError: if ``self.neg_strategy`` is unknown.
        """
        if self.neg_strategy in self._RAG_STRATEGIES:
            # Retrieval-based strategies share the batched code path.
            return self._batch_sample_negative_text([anchor_text], [anchor_idx])[0]

        if self.neg_strategy == "random_vocab":
            token_pool = self.vocab_indices
        elif self.neg_strategy == "qa_sampling":
            item = self.data[anchor_idx]
            token_pool = self.tokenizer.encode(item["question"] + " " + item["answer"], add_special_tokens=False)
        elif self.neg_strategy == "sentence_token_sampling":
            sentences = self.data[anchor_idx].get("reasoning_sentences", [])
            token_pool = (
                self.tokenizer.encode(" ".join(sentences), add_special_tokens=False) if sentences else []
            )
            # random.choices would raise on an empty pool.
            if not token_pool:
                return "[NO_REASONING]"
        else:
            raise ValueError(f"Unknown neg_strategy: {self.neg_strategy}")

        # Bag-of-tokens negative: tokens drawn with replacement, order is random.
        sampled_token_ids = random.choices(token_pool, k=self.neg_token_length)
        return self.tokenizer.decode(sampled_token_ids, skip_special_tokens=True)

    def _batch_sample_negative_text(self, anchor_texts: List[str], anchor_indices: List[int]) -> List[str]:
        """Return one negative snippet per (anchor_text, anchor_idx) pair.

        RAG strategies embed the anchors and run a single FAISS search for
        the whole batch.

        Raises:
            ValueError: if ``self.neg_strategy`` is unknown.
        """
        if self.neg_strategy in self._LOCAL_STRATEGIES:
            return [self._sample_negative_text(text, idx) for text, idx in zip(anchor_texts, anchor_indices)]
        if self.neg_strategy not in self._RAG_STRATEGIES:
            raise ValueError(f"Unknown neg_strategy: {self.neg_strategy}")
        if not anchor_texts:
            return []

        if self.neg_strategy == "vocab_rag":
            index, top_k = self.vocab_index, self._VOCAB_TOP_K
        elif self.neg_strategy == "sample_rag":
            index, top_k = self.index, self._NEIGHBOUR_TOP_K
        else:  # external_rag
            index, top_k = self.external_index, self._NEIGHBOUR_TOP_K

        _, neighbor_ids = index.search(self._embed_anchors(anchor_texts), top_k)
        return self._negatives_from_neighbors(neighbor_ids, anchor_indices)

    def _embed_anchors(self, anchor_texts: List[str]) -> np.ndarray:
        """Embed anchors, encoding each distinct text once (callers often repeat the same anchor)."""
        unique_texts = list(dict.fromkeys(anchor_texts))
        unique_embs = self._compute_embeddings(unique_texts, show_progress_bar=False)
        row_of = {text: row for row, text in enumerate(unique_texts)}
        anchor_embs = np.ascontiguousarray(unique_embs[[row_of[text] for text in anchor_texts]])
        faiss.normalize_L2(anchor_embs)
        return anchor_embs

    @staticmethod
    def _valid_ids(row, size: int) -> List[int]:
        """Keep the ids of one FAISS result row that index a collection of ``size`` items.

        FAISS pads a row with -1 when fewer than k results exist; unfiltered,
        -1 would silently wrap around to the last element of a Python list.
        """
        return [int(i) for i in row if 0 <= i < size]

    def _negatives_from_neighbors(self, neighbor_ids: np.ndarray, anchor_indices: List[int]) -> List[str]:
        """Turn FAISS search results (one row of ids per anchor) into negative texts."""
        neg_texts = []
        for row, anchor_idx in zip(neighbor_ids, anchor_indices):
            if self.neg_strategy == "vocab_rag":
                neg_texts.append(self._vocab_rag_negative(row))
            elif self.neg_strategy == "sample_rag":
                neg_texts.append(self._sample_rag_negative(row, anchor_idx))
            else:  # external_rag
                neg_texts.append(self._external_rag_negative(row))
        return neg_texts

    def _vocab_rag_negative(self, row) -> str:
        """Sample ``neg_token_length`` tokens (with replacement) from the retrieved vocab tokens."""
        candidates = [self.vocab_indices[i] for i in self._valid_ids(row, len(self.vocab_indices))]
        sampled_token_ids = random.choices(candidates, k=self.neg_token_length) if candidates else []
        return self.tokenizer.decode(sampled_token_ids, skip_special_tokens=True)

    def _sample_rag_negative(self, row, anchor_idx: int) -> str:
        """Pick a random reasoning sentence from the nearest other item that has any."""
        for i in self._valid_ids(row, len(self.data)):
            if i == anchor_idx:
                continue
            candidate_sentences = self.data[i].get("reasoning_sentences", [])
            if candidate_sentences:
                return random.choice(candidate_sentences)
        return "[NO_VALID_NEGATIVE]"

    def _external_rag_negative(self, row) -> str:
        """Take a random ``neg_token_length``-token window from the top valid external text."""
        valid = self._valid_ids(row, len(self.external_texts))
        if not valid:
            return "[NO_VALID_EXTERNAL_NEGATIVE]"
        token_ids = self.tokenizer.encode(self.external_texts[valid[0]], add_special_tokens=False)
        window = self._random_window(token_ids, self.neg_token_length)
        return self.tokenizer.decode(window, skip_special_tokens=True)

    def _select_anchor(self, item: Dict, reasoning_sentences: List[str]) -> str:
        """Return the anchor text for ``item`` according to ``self.anchor_source``."""
        if self.anchor_source == "answer":
            return item["answer"]
        if self.anchor_source == "question":
            return item["question"]
        if self.anchor_source == "reasoning_random":
            return random.choice(reasoning_sentences) if reasoning_sentences else item["answer"]
        raise ValueError(f"Unknown anchor_source {self.anchor_source}")

    @staticmethod
    def _tagged_output_path(output_path: str, num_samples: int) -> str:
        """Insert ``<num_samples>_harder_mixed`` before the file extension of ``output_path``.

        Only the final component's extension is touched; directory names
        containing ".json" are left alone.
        """
        root, ext = os.path.splitext(output_path)
        return f"{root}{num_samples}_harder_mixed{ext}"

    def build_and_save(self, output_path: str, mode: str):
        """Build the "pos" or "neg" dataset and write it as JSON.

        "pos" emits one label-1 sample per reasoning sentence. "neg" emits
        ``max(1, len(reasoning_sentences) // 3)`` label-0 hard-mixed samples
        per item. Items with no reasoning sentences are skipped. The sample
        count is inserted into the output filename (see
        ``_tagged_output_path``).

        Raises:
            ValueError: if ``mode`` or ``self.anchor_source`` is unsupported.
        """
        # Validate up front so a bad argument never produces an output file.
        if mode not in ("pos", "neg"):
            raise ValueError(f"Unsupported mode: {mode}, should be 'pos' or 'neg'.")
        if self.anchor_source not in ("answer", "question", "reasoning_random"):
            raise ValueError(f"Unknown anchor_source {self.anchor_source}")

        all_samples = []
        skipped = pos_count = neg_count = 0

        for idx, item in enumerate(tqdm(self.data, desc=f"Building {mode} dataset")):
            reasoning_sentences = item.get("reasoning_sentences", [])
            if not reasoning_sentences:
                skipped += 1
                continue
            answer_text = item["answer"]

            if mode == "pos":
                all_samples.extend(
                    {"answer": answer_text, "reason": pos_text, "label": 1}
                    for pos_text in reasoning_sentences
                )
                pos_count += len(reasoning_sentences)
                continue

            anchor_text = self._select_anchor(item, reasoning_sentences)
            # Roughly one negative per three positive sentences, at least one.
            neg_num = max(1, len(reasoning_sentences) // 3)
            neg_reasonings = self._batch_sample_negative_text([anchor_text] * neg_num, [idx] * neg_num)

            for neg_reasoning in neg_reasonings:
                hard_neg_info = self._construct_hard_negative(random.choice(reasoning_sentences), neg_reasoning)
                all_samples.append({
                    "answer": answer_text,
                    "reason": hard_neg_info["text"],
                    "label": 0,
                    "reason_type": "hard_mix",
                    "meta": hard_neg_info["meta"],
                })
            neg_count += len(neg_reasonings)

        output_path = self._tagged_output_path(output_path, len(all_samples))
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(all_samples, f, indent=2)

        print(f"✅ Saved {len(all_samples)} samples to {output_path}.")
        print(f"🔵 Positive samples: {pos_count}")
        print(f"🔵 Negative samples: {neg_count}")
        print(f"⚠️  Skipped {skipped} items due to empty reasoning_sentences.")

def main():
    """CLI entry point: load the input JSON, build the preprocessor, write the pos/neg dataset."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_json", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--model_name", type=str, default="sentence-transformers/all-MiniLM-L6-v2")
    parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased")
    parser.add_argument("--mode", type=str, choices=["pos", "neg"], required=True)
    parser.add_argument("--neg_token_length", type=int, default=20)
    # "external_rag" is supported by the preprocessor but was previously unreachable from the CLI.
    parser.add_argument("--neg_strategy", type=str, default="random_vocab", choices=[
        "random_vocab", "vocab_rag", "sample_rag", "qa_sampling", "sentence_token_sampling", "external_rag",
    ])
    parser.add_argument("--anchor_source", type=str, default="answer", choices=["answer", "question", "reasoning_random"])
    parser.add_argument("--cache_dir", type=str, default="./cache", help="Directory to cache embeddings and indices")
    parser.add_argument(
        "--wiki_dir", type=str, default=None,
        help="Root containing rag_index/index.faiss and rag_index/texts.pkl (external_rag only); "
             "defaults to the preprocessor's built-in path",
    )
    args = parser.parse_args()

    with open(args.input_json, "r", encoding="utf-8") as f:
        data = json.load(f)
    print(f"🔵 Loaded {len(data)} samples from {args.input_json}")

    embedder = SentenceTransformer(args.model_name)
    print(f"🔵 Loaded embedding model: {args.model_name}")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
    vocab_indices = list(range(tokenizer.vocab_size))
    print(f"🔵 Loaded tokenizer: {args.tokenizer_name}")

    # Only override wiki_dir when given, so the preprocessor's default path applies otherwise.
    extra_kwargs = {} if args.wiki_dir is None else {"wiki_dir": args.wiki_dir}
    preprocessor = FaissRAGReasoningPreprocessor(
        data, embedder, tokenizer, vocab_indices,
        neg_token_length=args.neg_token_length,
        neg_strategy=args.neg_strategy,
        anchor_source=args.anchor_source,
        cache_dir=args.cache_dir,
        **extra_kwargs,
    )
    preprocessor.build_and_save(args.output_path, args.mode)


if __name__ == "__main__":
    main()
